"""Module that creates workspaces for training/evaling various agents."""

import wandb
import torch
import shutil
from os import makedirs
from loguru import logger
from tqdm import tqdm
import numpy as np
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Union, Optional
from scipy import stats
import cv2

from rewards import RewardFunctionConstructor
from custom_dmc_tasks.point_mass_maze import GOALS as point_mass_maze_goals

from agents.base import AbstractWorkspace
from agents.fb.agent import FB
from agents.fb.replay_buffer import FBReplayBuffer, OnlineFBReplayBuffer



from agents.cql.agent import CQL
from agents.base import OfflineReplayBuffer

from agents.cfb.agent import CFB
from agents.gciql.agent import GCIQL

from agents.sf.agent import SF
from agents.base import D4RLReplayBuffer

import time
class OfflineRLWorkspace(AbstractWorkspace):
    """
    Trains, evaluates and rolls out an offline RL agent on the tasks
    defined by a reward constructor.
    """

    def __init__(
        self,
        reward_constructor: RewardFunctionConstructor,
        learning_steps: int,
        model_dir: Path,
        eval_frequency: int,
        eval_rollouts: int,
        wandb_logging: bool,
        device: torch.device,
        z_inference_steps: Optional[int] = None,  # FB only
        train_std: Optional[float] = None,  # FB only
        eval_std: Optional[float] = None,  # FB only
    ):
        """
        Stores the run configuration and passes the environment and reward
        functions of the reward constructor to the base workspace.
        Args:
            reward_constructor: supplies the environment, the reward
                functions and the domain name
            learning_steps: upper bound of the update loop in `train`
            model_dir: directory under which `train` creates run folders
            eval_frequency: how frequently (in steps) to eval
            eval_rollouts: how many rollouts per eval step
            wandb_logging: whether `train` logs to wandb
            device: device on which goal-state tensors are created
            z_inference_steps: forwarded as `inference_steps` when sampling
                task inference transitions from the replay buffer
            train_std: value written to `agent.std_dev_schedule` after eval
            eval_std: value written to `agent.std_dev_schedule` before eval
        """
        super().__init__(
            reward_functions=reward_constructor.reward_functions,
            env=reward_constructor._env,
        )

        # where the run happens and what gets logged
        self.domain_name = reward_constructor.domain_name
        self.device = device
        self.model_dir = model_dir
        self.wandb_logging = wandb_logging

        # training / evaluation schedule
        self.learning_steps = learning_steps
        self.eval_frequency = eval_frequency
        self.eval_rollouts = eval_rollouts

        # settings used for agents that act on an inferred task vector z
        self.z_inference_steps = z_inference_steps
        self.train_std = train_std
        self.eval_std = eval_std

        # transitions used for z inference; filled in by train / test_crl
        self.observations_z = None
        self.rewards_z = None
    def test_crl(
        self,
        agent: Union[FB, CFB],
        tasks: List[str],
        agent_config: Dict,
        replay_buffer: Union[OfflineReplayBuffer, FBReplayBuffer],
    ) -> None:
        """
        Loads a saved agent and runs the evaluation named in its config.
        Args:
            agent: agent whose weights are loaded from 'last.pickle'
            tasks: tasks to evaluate on
            agent_config: config dict; 'eval_method' selects the evaluation
                ('lagrange', 'constrained' or 'lagrange_hoffeding'). The
                dict is stored on the workspace as `self.agent_config`.
            replay_buffer: buffer from which task inference transitions are
                sampled (not used for the point_mass_maze domain)
        """
        # the checkpoint path is relative to the current working directory
        best_model_path = 'last.pickle'

        agent.load(best_model_path)
        self.agent_config = agent_config

        # prepare the data that the eval methods use to infer z
        if isinstance(agent, FB):
            if self.domain_name == "point_mass_maze":
                self.goal_states = {}
                for task, goal_state in point_mass_maze_goals.items():
                    self.goal_states[task] = torch.tensor(
                        goal_state, dtype=torch.float32, device=self.device
                    ).unsqueeze(0)
            else:
                (
                    self.observations_z,
                    self.rewards_z,
                ) = replay_buffer.sample_task_inference_transitions(
                    inference_steps=self.z_inference_steps
                )

        eval_method = agent_config['eval_method']
        if eval_method == 'lagrange':
            self.lagrange_eval(agent=agent, tasks=tasks)
        elif eval_method == 'constrained':
            self.constrained_eval(agent=agent, tasks=tasks)
        elif eval_method == 'lagrange_hoffeding':
            self.lagrange_hoffeding_eval(agent=agent, tasks=tasks)
        else:
            # previously an unknown method was skipped without any notice
            logger.warning(
                f"Unknown eval_method '{eval_method}'; no evaluation was run."
            )
       

    def train(
        self,
        agent: Union[CQL, FB, CFB, GCIQL],
        tasks: List[str],
        agent_config: Dict,
        replay_buffer: Union[OfflineReplayBuffer, FBReplayBuffer],
        seed: int,
    ) -> None:
        """
        Trains an offline RL agent, evaluating and checkpointing it every
        `eval_frequency` steps.
        Args:
            agent: agent to train
            tasks: tasks to evaluate on
            agent_config: config dict passed to wandb
            replay_buffer: buffer that training batches (and, for agents
                acting on z, task inference transitions) are sampled from
            seed: seed, used only in the wandb run name
        """
        if self.wandb_logging:
            run = wandb.init(
                config=agent_config,
                tags=[agent.name],
                reinit=True,
                name='offline_{}_{}_seed_{}_eta1_grad_limit{}_FROMfbcz_check2'.format(self.domain_name, agent.name, seed, agent.grad_limit)
            )
            model_path = self.model_dir / run.name
        else:
            # "%Y" (not a literal "Y") so the folder name carries the year
            date = datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
            model_path = self.model_dir / f"local-run-{date}"
        makedirs(str(model_path))

        logger.info(f"Training {agent.name}.")
        best_mean_task_reward = -np.inf
        best_model_path = None

        # sample set transitions for z inference
        if isinstance(agent, (FB, SF, GCIQL)):
            if self.domain_name == "point_mass_maze":
                self.goal_states = {}
                for task, goal_state in point_mass_maze_goals.items():
                    self.goal_states[task] = torch.tensor(
                        goal_state, dtype=torch.float32, device=self.device
                    ).unsqueeze(0)
            else:
                (
                    self.observations_z,
                    self.rewards_z,
                ) = replay_buffer.sample_task_inference_transitions(
                    inference_steps=self.z_inference_steps
                )

        for i in tqdm(range(self.learning_steps + 1)):

            batch = replay_buffer.sample(agent.batch_size)
            train_metrics = agent.update(batch=batch, step=i)

            eval_metrics = {}

            if i % self.eval_frequency == 0:
                # always keep the most recent weights under the name "last"
                agent._name = "last"
                agent.save(model_path)
                eval_metrics = self.eval(agent=agent, tasks=tasks)

                if eval_metrics["eval/task_reward_iqm"] > best_mean_task_reward:
                    agent._name = "best"
                    logger.info(
                        f"New max IQM task reward: {best_mean_task_reward:.3f} -> "
                        f"{eval_metrics['eval/task_reward_iqm']:.3f}."
                        f" Saving model."
                    )

                    # delete current best model
                    if best_model_path is not None:
                        best_model_path.unlink(missing_ok=True)
                    # save locally
                    best_model_path = agent.save(model_path)

                    best_mean_task_reward = eval_metrics["eval/task_reward_iqm"]
                all_model_save_path = agent.save_frequently(model_path, i)
                print("save frequently model_save_path", all_model_save_path)
                # eval() switches the agent to eval mode; switch back
                agent.train()

            metrics = {**train_metrics, **eval_metrics}

            if self.wandb_logging:
                run.log(metrics)

        if self.wandb_logging:
            # save to wandb_logging; there is nothing to upload when no
            # evaluation ever beat the initial -inf (e.g. NaN metrics)
            if best_model_path is not None:
                run.save(best_model_path.as_posix(), base_path=model_path.as_posix())
            else:
                logger.warning("No best model was saved; skipping wandb upload.")
            run.finish()

        # Local checkpoints are kept on disk; cleanup via
        # shutil.rmtree(model_path) is disabled.

    def eval(
        self,
        agent: Union[CQL, FB, CFB],
        tasks: List[str],
    ) -> Dict[str, float]:
        """
        Performs eval rollouts.
        Args:
            agent: agent to evaluate
            tasks: tasks to evaluate on
        Returns:
            metrics: dict with the per-task IQM episode reward under
                "eval/{task}/episode_reward_iqm" and their mean over
                `tasks` under "eval/task_reward_iqm"
        """
        # agents that act on a task vector z get one z per task up front
        uses_task_z = isinstance(agent, (FB, SF, GCIQL))
        zs = {}
        if uses_task_z:
            if self.domain_name == "point_mass_maze":
                for task, goal_state in self.goal_states.items():
                    zs[task] = agent.infer_z(goal_state)
            else:
                for task, rewards in self.rewards_z.items():
                    zs[task] = agent.infer_z(self.observations_z, rewards)

            agent.std_dev_schedule = self.eval_std

        # The agent type does not change during the rollouts, so decide the
        # acting convention once. FB is tested before SF, as before.
        acts_on_fixed_z = isinstance(agent, (FB, GCIQL))
        acts_as_sf = isinstance(agent, SF)

        logger.info("Performing eval rollouts.")
        eval_rewards = {}
        agent.eval()
        for _ in tqdm(range(self.eval_rollouts)):

            for task in tasks:
                task_rewards = 0.0

                timestep = self.env.reset()
                while not timestep.last():
                    observation = timestep.observation["observations"]
                    if acts_on_fixed_z:
                        action, _ = agent.act(
                            observation,
                            task=zs[task],
                            step=None,
                            sample=False,
                        )

                    elif acts_as_sf:
                        if self.domain_name != "point_mass_maze":
                            z = zs[task]
                        # calculate z at every step
                        else:
                            z = agent.infer_z_from_goal(
                                observation=observation,
                                goal_state=self.goal_states[task],
                            )
                        action, _ = agent.act(
                            observation,
                            task=z,
                            step=None,
                            sample=False,
                        )

                    else:
                        action = agent.act(
                            observation,
                            sample=False,
                            step=None,
                        )
                    timestep = self.env.step(action)
                    task_rewards += self.reward_functions[task](self.env.physics)

                eval_rewards.setdefault(task, []).append(task_rewards)

        # average over rollouts for metrics
        metrics = {}
        mean_task_performance = 0.0
        for task, rewards in eval_rewards.items():
            mean_task_reward = stats.trim_mean(rewards, 0.25)  # IQM
            metrics[f"eval/{task}/episode_reward_iqm"] = mean_task_reward
            mean_task_performance += mean_task_reward

        # log mean task performance
        metrics["eval/task_reward_iqm"] = mean_task_performance / len(tasks)

        # restore the training std for every agent it was overridden for
        if uses_task_z:
            agent.std_dev_schedule = self.train_std

        return metrics

    def eval_cz(
        self,
        agent: Union[CQL, FB, CFB],
        tasks: List[str],
        batch
    ) -> Dict[str, float]:
        """
        Performs two phases of eval rollouts ("nocz" and "withcz") while
        counting how often one randomly chosen observation dimension lies
        in a randomly chosen tail of its distribution in `batch`.
        Args:
            agent: agent to evaluate
            tasks: tasks to evaluate on
            batch: object whose `observations` attribute is a 2-D tensor;
                used to pick the cost dimension and its threshold. It is
                not modified.
        Returns:
            metrics: dict of metrics
        """
        uses_task_z = isinstance(agent, (FB, SF, GCIQL))
        zs = {}
        if uses_task_z:
            if self.domain_name == "point_mass_maze":
                for task, goal_state in self.goal_states.items():
                    zs[task] = agent.infer_z(goal_state)
            else:
                for task, rewards in self.rewards_z.items():
                    zs[task] = agent.infer_z(self.observations_z, rewards)

            agent.std_dev_schedule = self.eval_std

        # Pick one observation dimension at random and make either its upper
        # or its lower quartile (coin flip) the costly region.
        backward_input = batch.observations
        cost_dim = int(torch.randperm(backward_input.shape[-1])[0])
        column = backward_input[:, cost_dim]
        col_np = column.cpu().numpy()
        if np.random.rand() < 0.5:
            large_flag = True
            low_val = np.percentile(col_np, 75)
            high_val = column.cpu().max().item()
        else:
            large_flag = False
            low_val = column.cpu().min().item()
            high_val = np.percentile(col_np, 25)
        values = np.random.uniform(low=low_val, high=high_val, size=backward_input.shape[0])

        # Work on a copy so the caller's batch is left untouched.
        cost_backward_input = backward_input.clone()
        cost_backward_input[:, cost_dim] = torch.tensor(
            values, dtype=backward_input.dtype, device=cost_backward_input.device
        )

        if uses_task_z:
            # NOTE(review): cz is computed but not applied in either rollout
            # phase below, so "withcz" currently evaluates the same policy
            # as "nocz" -- confirm how zs and cz were meant to be combined.
            cz = agent.infer_z(cost_backward_input, 0*cost_backward_input[:,0] + 1)
        eta = 200*np.random.rand()

        # a step is costly above low_val (upper tail) or below high_val
        # (lower tail)
        threshold = low_val if large_flag else high_val

        logger.info("Performing eval rollouts.")
        agent.eval()
        eval_rewards_nocz, eval_rewards_cost_nocz = self._eval_cz_rollouts(
            agent, tasks, zs, cost_dim, large_flag, threshold
        )
        eval_rewards_withcz, eval_rewards_cost_withcz = self._eval_cz_rollouts(
            agent, tasks, zs, cost_dim, large_flag, threshold
        )

        # average over rollouts for metrics
        metrics = {}
        mean_task_performance_nocz = 0.0
        mean_task_performance_withcz = 0.0
        mean_overeta = 0.0
        for task, rewards in eval_rewards_nocz.items():
            mean_task_reward = stats.trim_mean(rewards, 0.25)  # IQM
            metrics[f"eval/{task}/episode_reward_nocz_iqm"] = mean_task_reward
            mean_task_performance_nocz += mean_task_reward
        for task, costs in eval_rewards_cost_nocz.items():
            mean_task_cost = stats.trim_mean(costs, 0.25)  # IQM
            metrics[f"eval/{task}/episode_reward_cost_nocz_iqm"] = mean_task_cost
            mean_task_performance_nocz -= mean_task_cost

        for task, rewards in eval_rewards_withcz.items():
            mean_task_reward = stats.trim_mean(rewards, 0.25)  # IQM
            metrics[f"eval/{task}/episode_reward_withcz_iqm"] = mean_task_reward
            mean_task_performance_withcz += mean_task_reward
        for task, costs in eval_rewards_cost_withcz.items():
            mean_task_cost = stats.trim_mean(costs, 0.25)  # IQM
            metrics[f"eval/{task}/episode_reward_cost_withcz_iqm"] = mean_task_cost

            mean_task_performance_withcz -= mean_task_cost

            # accumulate by how much the cost exceeds the sampled budget eta
            if mean_task_cost > eta:
                mean_overeta += (mean_task_cost-eta)

        # log mean task performance
        metrics["eval/task_overall_nocz_iqm"] = mean_task_performance_nocz / len(tasks)
        metrics["eval/task_overall_withcz_iqm"] = mean_task_performance_withcz / len(tasks)
        metrics["eval/mean_overeta"] = mean_overeta / len(tasks)

        # restore the training std for every agent it was overridden for
        if uses_task_z:
            agent.std_dev_schedule = self.train_std

        return metrics

    def _eval_cz_rollouts(self, agent, tasks, zs, cost_dim, large_flag, threshold):
        """
        Runs `self.eval_rollouts` episodes per task and returns two dicts
        mapping task -> list of episode rewards / list of episode costs.
        The cost of an episode is the number of steps whose observation at
        `cost_dim` is above `threshold` (large_flag) or below it (otherwise).
        """
        acts_on_fixed_z = isinstance(agent, (FB, GCIQL))
        acts_as_sf = isinstance(agent, SF)

        episode_rewards = {}
        episode_costs = {}
        for _ in tqdm(range(self.eval_rollouts)):

            for task in tasks:
                task_rewards = 0.0
                task_cost = 0.0

                timestep = self.env.reset()
                while not timestep.last():
                    observation = timestep.observation["observations"]
                    if acts_on_fixed_z:
                        action, _ = agent.act(
                            observation,
                            task=zs[task],
                            step=None,
                            sample=False,
                        )

                    elif acts_as_sf:
                        if self.domain_name != "point_mass_maze":
                            z = zs[task]
                        # calculate z at every step
                        else:
                            z = agent.infer_z_from_goal(
                                observation=observation,
                                goal_state=self.goal_states[task],
                            )
                        action, _ = agent.act(
                            observation,
                            task=z,
                            step=None,
                            sample=False,
                        )

                    else:
                        action = agent.act(
                            observation,
                            sample=False,
                            step=None,
                        )
                    timestep = self.env.step(action)
                    task_rewards += self.reward_functions[task](self.env.physics)

                    # cost is judged on the observation after the step
                    value = timestep.observation["observations"][cost_dim]
                    if large_flag:
                        if value > threshold:
                            task_cost += 1
                    else:
                        if value < threshold:
                            task_cost += 1

                episode_rewards.setdefault(task, []).append(task_rewards)
                episode_costs.setdefault(task, []).append(task_cost)

        return episode_rewards, episode_costs
    def constrained_eval(
        self,
        agent: Union[FB, CFB], #Union[CQL, SAC, FB, CFB, CalFB]
        tasks: List[str],
    ) -> Dict[str, float]:
        """
        Performs eval rollouts with a cost on the last observation dimension.

        The first phase rolls out the agent with `zs[task] - c_zs[task]`
        (FB agents) and prints its reward/cost metrics. For FB agents a
        second, "naive" phase then rolls out with `zs[task]` alone.
        Args:
            agent: agent to evaluate
            tasks: tasks to evaluate on (exactly one for point_mass_maze)
        Returns:
            metrics: dict of metrics of the last phase that was run, i.e.
                the naive phase for FB agents and the first phase otherwise
        Raises:
            ValueError: if more than one task is given for point_mass_maze
        """
        video_flag = False  # set to True to write an mp4 per episode
        is_fb = isinstance(agent, FB)
        zs = {}
        if is_fb:
            if self.domain_name == "point_mass_maze":
                for task, goal_state in self.goal_states.items():
                    zs[task] = agent.infer_z(goal_state)
            else:
                for task, rewards in self.rewards_z.items():
                    zs[task] = agent.infer_z(self.observations_z, rewards)

            agent.std_dev_schedule = self.eval_std

        specific_dimensions = [-1]

        sample_range = self.sample_range
        print("constrained eval sample range set to:", sample_range)
        eta = self.agent_config['eta']
        print("constrained eval eta set to:", eta)

        if self.domain_name == "point_mass_maze":
            if len(tasks) > 1:
                print("Only one task is allowed for point_mass_maze")
                raise ValueError("Only one task is allowed for point_mass_maze")
            observations_c_z = self.goal_states[tasks[0]].clone()
        else:
            observations_c_z = self.observations_z.clone()

        # overwrite the constrained dimension(s) with uniform samples
        for dim in specific_dimensions:
            random_samples = np.random.uniform(sample_range[0], sample_range[1], size=observations_c_z.shape[0])
            observations_c_z[:, dim] = torch.tensor(random_samples, dtype=torch.float32).to(observations_c_z.device)

        # cost signal handed to infer_z: 10 * (sample - lower bound)
        cost_targets = 10*(
            torch.tensor(random_samples, dtype=torch.float32)
            .to(observations_c_z.device)
            .unsqueeze(1)
            - sample_range[0]
        )
        c_zs = {}
        if self.domain_name == "point_mass_maze":
            for task in self.goal_states:
                c_zs[task] = agent.infer_z(observations_c_z, cost_targets)
        else:
            for task, rewards in self.rewards_z.items():
                c_zs[task] = agent.infer_z(observations_c_z, (-rewards*0)+cost_targets)

        logger.info("Performing eval rollouts.")
        agent.eval()

        # FB agents act on the reward z minus the cost z; other agents act
        # without a task vector.
        constrained_zs = (
            {task: zs[task]-c_zs[task] for task in tasks} if is_fb else None
        )
        eval_rewards, eval_cost = self._constrained_eval_rollouts(
            agent, tasks, constrained_zs, specific_dimensions, sample_range[0], video_flag
        )
        metrics = self._constrained_eval_summary(eval_rewards, eval_cost, len(tasks))

        if is_fb:
            print("Naive constrained FB")
            eval_rewards, eval_cost = self._constrained_eval_rollouts(
                agent, tasks, zs, specific_dimensions, sample_range[0], video_flag
            )
            # as before, the returned metrics are those of the naive phase
            metrics = self._constrained_eval_summary(eval_rewards, eval_cost, len(tasks))

            agent.std_dev_schedule = self.train_std

        return metrics

    def _constrained_eval_rollouts(
        self, agent, tasks, task_zs, cost_dims, cost_floor, video_flag=False
    ):
        """
        Runs `self.eval_rollouts` episodes per task and returns two dicts
        mapping task -> list of episode rewards / list of episode costs.
        When `task_zs` is None the agent acts without a task vector.
        """
        is_maze = self.domain_name == "point_mass_maze"
        episode_rewards = {}
        episode_costs = {}
        for _ in tqdm(range(self.eval_rollouts)):
            for task in tasks:
                task_rewards = 0.0
                task_cost = 0.0

                timestep = self.env.reset()

                out = None
                if video_flag:
                    width = 480
                    height = 480
                    output_file = "{}_{}.mp4".format(self.domain_name, task)
                    fps = 30  # Frames per second
                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for .mp4 format
                    out = cv2.VideoWriter(output_file, fourcc, fps, (2 * width, height))

                while not timestep.last():
                    if task_zs is not None:
                        action, _ = agent.act(
                            timestep.observation["observations"],
                            task=task_zs[task],
                            step=None,
                            sample=False,
                        )
                    else:
                        action = agent.act(
                            timestep.observation["observations"],
                            sample=False,
                            step=None,
                        )
                    timestep = self.env.step(action)
                    task_rewards += self.reward_functions[task](self.env.physics)

                    # cost grows linearly with the excess over cost_floor
                    observation = timestep.observation["observations"]
                    for dim in cost_dims:
                        if is_maze:
                            if abs(observation[dim]) > cost_floor:
                                task_cost += 100*(abs(observation[dim]) - cost_floor)
                        else:
                            if observation[dim] > cost_floor:
                                task_cost += 10*(observation[dim] - cost_floor)

                    if out is not None:
                        frame = np.hstack([self.env.physics.render(height, width, camera_id=0),
                                            self.env.physics.render(height, width, camera_id=1)])
                        # Write frame to video
                        out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))  # Convert RGB to BGR

                # close the writer of this episode before the next one opens
                if out is not None:
                    out.release()
                    print(f"Video saved as {output_file}")

                episode_rewards.setdefault(task, []).append(task_rewards)
                episode_costs.setdefault(task, []).append(task_cost)

        return episode_rewards, episode_costs

    @staticmethod
    def _constrained_eval_summary(eval_rewards, eval_cost, num_tasks) -> Dict[str, float]:
        """
        Prints and returns the per-task IQM reward/cost and their means
        over `num_tasks`.
        """
        # average over rollouts for metrics
        metrics = {}
        mean_task_performance = 0.0
        mean_task_cost_value = 0.0
        for task, rewards in eval_rewards.items():
            mean_task_reward = stats.trim_mean(rewards, 0.25)  # IQM
            metrics[f"eval/{task}/episode_reward_iqm"] = mean_task_reward
            print(f"eval/{task}/episode_reward_iqm", mean_task_reward)
            mean_task_performance += mean_task_reward
        for task, costs in eval_cost.items():
            mean_task_cost = stats.trim_mean(costs, 0.25)
            metrics[f"eval/{task}/episode_cost_iqm"] = mean_task_cost
            print(f"eval/{task}/episode_cost_iqm", mean_task_cost)
            mean_task_cost_value += mean_task_cost

        # log mean task performance
        metrics["eval/task_reward_iqm"] = mean_task_performance / num_tasks
        metrics["eval/task_cost_iqm"] = mean_task_cost_value / num_tasks
        print("eval/task_reward_iqm", mean_task_performance / num_tasks)
        print("eval/task_cost_iqm", mean_task_cost_value / num_tasks)
        return metrics


    def lagrange_eval(
        self,
        agent: Union[FB, CFB], #Union[CQL, SAC, FB, CFB, CalFB]
        tasks: List[str],
    ) -> Dict[str, float]:
        """
        Performs cost-constrained eval rollouts with an adaptive cost budget.

        A cost is charged whenever the constrained observation entry (the
        last dimension) exceeds `self.sample_range[0]`. A cost task vector
        `c_z` is inferred for that penalty and the policy is conditioned on
        `z - λ * c_z`. Two nested searches are run:
          * inner: bisection on the Lagrange multiplier λ in [0, 10] (at most
            50 steps) until the cost value estimated from the forward
            representation at a reset state is within 2 of the budget η';
          * outer: up to 35 rounds of one rollout per task; η' is lowered when
            the measured mean episode cost exceeds `self.agent_config['eta']`
            and raised otherwise, stopping once the cost is below eta by less
            than the tolerance.

        Args:
            agent: agent to evaluate (must be an FB instance)
            tasks: tasks to evaluate on (exactly one for point_mass_maze)
        Returns:
            metrics: dict of metrics; per-task reward/cost IQMs pooled over
                all outer rounds, their averages across tasks and the final
                budget η'
        Raises:
            TypeError: if `agent` is not an FB instance
            ValueError: if several tasks are given for point_mass_maze
        """
        # Validate before touching any agent state so a rejected call leaves
        # the agent's std_dev_schedule untouched.
        if not isinstance(agent, FB):
            # Task vectors are only inferred for FB agents; other agents would
            # otherwise fail further down on an unbound `zs`.
            raise TypeError(
                f"lagrange_eval requires an FB agent, got {type(agent).__name__}"
            )
        if self.domain_name == "point_mass_maze" and len(tasks) > 1:
            print("Only one task is allowed for point_mass_maze")
            raise ValueError("Only one task is allowed for point_mass_maze")

        # reward task vectors
        zs = {}
        if self.domain_name == "point_mass_maze":
            for task, goal_state in self.goal_states.items():
                zs[task] = agent.infer_z(goal_state)
        else:
            for task, rewards in self.rewards_z.items():
                zs[task] = agent.infer_z(self.observations_z, rewards)

        agent.std_dev_schedule = self.eval_std

        specific_dimensions = [-1]
        sample_range = self.sample_range
        print("lagrange eval sample range set to:", sample_range)

        eta = self.agent_config['eta']
        print("lagrange eval eta set to:", eta)
        budget_eta = eta  # Initialize budget η'

        budget_min = eta * (-3.5)
        budget_max = eta * 1
        budget_step = 5.0
        budget_tolerance = 5.0

        if self.domain_name == "point_mass_maze":
            observations_c_z = self.goal_states[tasks[0]].clone()
        else:
            observations_c_z = self.observations_z.clone()

        # Overwrite the constrained dimension with uniform samples from the
        # constrained range; these states are used to infer the cost vector.
        for dim in specific_dimensions:
            random_samples = np.random.uniform(
                sample_range[0], sample_range[1], size=observations_c_z.shape[0]
            )
            sampled_values = torch.tensor(random_samples, dtype=torch.float32).to(
                observations_c_z.device
            )
            observations_c_z[:, dim] = sampled_values

        # Pseudo-reward for the cost vector: 10x the excess over the lower
        # bound of the sampled range (one column, one row per sampled state).
        cost_targets = 10*(sampled_values.unsqueeze(1)-sample_range[0])

        c_zs = {}
        if self.domain_name == "point_mass_maze":
            for task, goal_state in self.goal_states.items():
                c_zs[task] = agent.infer_z(observations_c_z, cost_targets)
        else:
            for task, rewards in self.rewards_z.items():
                # (-rewards*0) contributes zeros shaped like the task rewards.
                c_zs[task] = agent.infer_z(observations_c_z, (-rewards*0)+cost_targets)

        logger.info("Performing constrained eval rollouts with adaptive budget.")
        eval_rewards = {}
        eval_cost = {}

        agent.eval()

        total_samples = 0
        for outer_iter in range(35):
            mean_task_cost = 0.0
            mean_task_rewards = 0.0

            for task in tasks:
                # Initialize environment first to get initial timestep
                timestep = self.env.reset()

                # Lagrange multiplier search range
                lagrange_min = 0
                lagrange_max = 10
                lagrange_multiplier = np.random.uniform(lagrange_min, lagrange_max)

                sample_obs = timestep.observation["observations"]
                observation = torch.as_tensor(
                    sample_obs, dtype=torch.float32, device=self.device
                ).unsqueeze(0)
                c_z = torch.as_tensor(
                    c_zs[task], dtype=torch.float32, device=self.device
                ).unsqueeze(0)

                # Binary search to find suitable λ such that F*Z_c = η'
                for _ in range(50):  # Limit search iterations
                    Z_lambda_c = zs[task] - lagrange_multiplier*c_zs[task]

                    action, _ = agent.act(
                        sample_obs,
                        task=Z_lambda_c,
                        step=None,
                        sample=False,
                    )
                    action = torch.as_tensor(
                        action, dtype=torch.float32, device=self.device
                    ).unsqueeze(0)
                    Z_lambda_c_tensor = torch.as_tensor(
                        Z_lambda_c, dtype=torch.float32, device=self.device
                    ).unsqueeze(0)

                    # Calculate estimated cost value (F*Z_c); evaluation only,
                    # so no autograd graph is needed.
                    with torch.no_grad():
                        F1, F2 = agent.FB.forward_representation(
                            observation=observation, z=Z_lambda_c_tensor, action=action
                        )
                        Q1 = torch.einsum("sd, sd -> s", F1, c_z)
                        Q2 = torch.einsum("sd, sd -> s", F2, c_z)
                        Q = torch.max(Q1, Q2).squeeze(0)

                    # Inner loop: adjust λ based on comparison with budget η'
                    if abs(Q - budget_eta) < 2:
                        # Report the estimate that triggered convergence; no
                        # rollout cost exists yet at this point.
                        print(f"Converged! Budget η': {budget_eta}, Estimated cost: {float(Q)}, ETA: {eta}")
                        break
                    elif Q > budget_eta:  # If F*Z_c > η': increase λ
                        lagrange_min = lagrange_multiplier
                        lagrange_multiplier = (lagrange_multiplier + lagrange_max) / 2
                    elif Q < budget_eta:  # If F*Z_c < η': decrease λ
                        lagrange_max = lagrange_multiplier
                        lagrange_multiplier = (lagrange_multiplier + lagrange_min) / 2

                # Use the found λ to test in actual environment
                task_rewards = 0.0
                task_cost = 0.0
                # NOTE(review): λ was calibrated on the previous reset state;
                # the rollout starts from a fresh reset — confirm intended.
                timestep = self.env.reset()
                Z_lambda_c = zs[task] - lagrange_multiplier*c_zs[task]

                while not timestep.last():
                    action, _ = agent.act(
                        timestep.observation["observations"],
                        task=Z_lambda_c,
                        step=None,
                        sample=False,
                    )

                    timestep = self.env.step(action)
                    total_samples += 1
                    task_rewards += self.reward_functions[task](self.env.physics)

                    for dim in specific_dimensions:
                        value = timestep.observation["observations"][dim]
                        if self.domain_name == "point_mass_maze":
                            if abs(value) > sample_range[0]:
                                task_cost += 100*(abs(value) - sample_range[0])
                        else:
                            if value > sample_range[0]:
                                task_cost += 10*(value - sample_range[0])

                eval_rewards.setdefault(task, []).append(task_rewards)
                eval_cost.setdefault(task, []).append(task_cost)

                mean_task_cost += task_cost
                mean_task_rewards += task_rewards

            # Calculate average cost and rewards
            mean_task_cost /= len(tasks)
            mean_task_rewards /= len(tasks)
            print(f"Iter {outer_iter}, Budget η': {budget_eta}, Actual cost: {mean_task_cost}, Reward: {mean_task_rewards}")

            # Outer loop: adjust budget η' based on actual cost compared to target η
            if abs(mean_task_cost - eta) < budget_tolerance and mean_task_cost < eta:
                print(f"Converged! Budget η': {budget_eta}, Actual cost: {mean_task_cost}, Total samples: {total_samples}")
                break
            elif mean_task_cost > eta:
                budget_max = budget_eta
                budget_eta = (budget_eta + budget_min)/2
                budget_eta = max(budget_eta, budget_min)
                print(f"Cost too high, decreasing budget to {budget_eta}, Total samples: {total_samples}")
            else:
                budget_min = budget_eta
                budget_eta += budget_step
                budget_eta = min(budget_eta, budget_max)
                print(f"Cost too low, increasing budget to {budget_eta}, Total samples: {total_samples}")

        # Calculate evaluation metrics (keeping original metrics calculation)
        metrics = {}
        mean_task_performance = 0.0
        mean_task_cost_value = 0.0
        for task, rewards in eval_rewards.items():
            reward_iqm = stats.trim_mean(rewards, 0.25)  # IQM
            metrics[f"eval/{task}/episode_reward_iqm"] = reward_iqm
            print(f"eval/{task}/episode_reward_iqm", reward_iqm)
            mean_task_performance += reward_iqm
        for task, costs in eval_cost.items():
            cost_iqm = stats.trim_mean(costs, 0.25)
            metrics[f"eval/{task}/episode_cost_iqm"] = cost_iqm
            print(f"eval/{task}/episode_cost_iqm", cost_iqm)
            mean_task_cost_value += cost_iqm

        # Add new metrics for the adaptive budget approach
        metrics["eval/task_reward_iqm"] = mean_task_performance / len(tasks)
        metrics["eval/task_cost_iqm"] = mean_task_cost_value / len(tasks)
        metrics["eval/final_budget_eta"] = budget_eta

        print("Evaluation results:")
        print(f"Evaluation reward: {mean_task_performance / len(tasks)}")
        print(f"Evaluation cost: {mean_task_cost_value / len(tasks)}")
        print(f"Final budget η': {budget_eta}")

        # restore the exploration noise used for training
        agent.std_dev_schedule = self.train_std

        return metrics
    
    def lagrange_hoffeding_eval(
        self,
        agent: Union[FB, CFB], #Union[CQL, SAC, FB, CFB, CalFB]
        tasks: List[str],
    ) -> Dict[str, float]:
        """
        Performs cost-constrained eval rollouts with a Hoeffding-style tolerance.

        Same scheme as the adaptive-budget Lagrangian eval: the policy is
        conditioned on `z - λ * c_z`, λ is found by bisection so that the cost
        value estimated from the forward representation matches the budget η',
        and η' is adapted over up to 35 outer rounds. Differences here:
          * every round runs `agent_config['hoffeding_sample']` rollouts per
            task, each preceded by a λ search from that rollout's reset state;
          * the convergence tolerance is
            sqrt(max_cost^2 * ln(2 / delta) / (2 * n)) with
            delta = `agent_config['hoffeding_delta']`, n the number of
            rollouts and max_cost the largest episode cost seen in the round
            (taken from the last task in `tasks`);
          * η' is bisected during the first three rounds and then moved by a
            fixed step.

        Args:
            agent: agent to evaluate (must be an FB instance)
            tasks: tasks to evaluate on (exactly one for point_mass_maze)
        Returns:
            metrics: dict of metrics; per-task reward/cost IQMs of the last
                outer round, their averages across tasks and the final
                budget η'
        Raises:
            TypeError: if `agent` is not an FB instance
            ValueError: if several tasks are given for point_mass_maze
        """
        # Validate before touching any agent state so a rejected call leaves
        # the agent's std_dev_schedule untouched.
        if not isinstance(agent, FB):
            # Task vectors are only inferred for FB agents; other agents would
            # otherwise fail further down on an unbound `zs`.
            raise TypeError(
                f"lagrange_hoffeding_eval requires an FB agent, got {type(agent).__name__}"
            )
        if self.domain_name == "point_mass_maze" and len(tasks) > 1:
            print("Only one task is allowed for point_mass_maze")
            raise ValueError("Only one task is allowed for point_mass_maze")

        # reward task vectors
        zs = {}
        if self.domain_name == "point_mass_maze":
            for task, goal_state in self.goal_states.items():
                zs[task] = agent.infer_z(goal_state)
        else:
            for task, rewards in self.rewards_z.items():
                zs[task] = agent.infer_z(self.observations_z, rewards)

        agent.std_dev_schedule = self.eval_std

        # Define the specific dimensions and the range for random samples
        specific_dimensions = [-1]
        sample_range = self.sample_range
        print("lagrange eval sample range set to:", sample_range)

        eta = self.agent_config['eta']
        print("lagrange eval eta set to:", eta)

        # (config keys keep their historical "hoffeding" spelling)
        num_samples = self.agent_config['hoffeding_sample']
        hoeffding_delta = self.agent_config['hoffeding_delta']

        # Budget search state. It is initialised once, outside the outer loop,
        # so that each round's adjustment carries over to the next round.
        budget_eta = eta  # Initialize budget η'
        budget_min = eta * (-4)
        budget_max = eta * 1
        # budget_min = eta * (-3.5) # point_mass_maze
        # budget_max = eta * 1 # point_mass_maze
        budget_step = 5.0

        if self.domain_name == "point_mass_maze":
            observations_c_z = self.goal_states[tasks[0]].clone()
        else:
            # Create a copy of observations_z and modify the specified dimensions
            observations_c_z = self.observations_z.clone()

        # Overwrite the constrained dimension with uniform samples from the
        # constrained range; these states are used to infer the cost vector.
        for dim in specific_dimensions:
            random_samples = np.random.uniform(
                sample_range[0], sample_range[1], size=observations_c_z.shape[0]
            )
            sampled_values = torch.tensor(random_samples, dtype=torch.float32).to(
                observations_c_z.device
            )
            observations_c_z[:, dim] = sampled_values

        # Pseudo-reward for the cost vector: 10x the excess over the lower
        # bound of the sampled range (one column, one row per sampled state).
        cost_targets = 10*(sampled_values.unsqueeze(1)-sample_range[0])

        c_zs = {}
        if self.domain_name == "point_mass_maze":
            for task, goal_state in self.goal_states.items():
                c_zs[task] = agent.infer_z(observations_c_z, cost_targets)
        else:
            for task, rewards in self.rewards_z.items():
                # (-rewards*0) contributes zeros shaped like the task rewards.
                c_zs[task] = agent.infer_z(observations_c_z, (-rewards*0)+cost_targets)

        print("c_zs", c_zs)
        print("zs", zs)

        logger.info("Performing constrained eval rollouts with adaptive budget.")

        agent.eval()

        total_samples = 0
        for outer_iter in range(35):  # Maximum number of budget adjustment rounds
            mean_task_cost = 0.0
            mean_task_rewards = 0.0

            # Rollout statistics are kept per round only.
            eval_rewards = {}
            eval_cost = {}

            for task in tasks:
                # Initialize environment first to get initial timestep
                timestep = self.env.reset()

                # Lagrange multiplier search range; shared by all rollouts of
                # this task, so later searches start from the narrowed bracket.
                lagrange_min = 0
                lagrange_max = 10

                # Randomly initialize λ
                lagrange_multiplier = np.random.uniform(lagrange_min, lagrange_max)

                c_z = torch.as_tensor(
                    c_zs[task], dtype=torch.float32, device=self.device
                ).unsqueeze(0)

                for _ in range(num_samples):
                    task_rewards = 0.0
                    task_cost = 0.0
                    timestep = self.env.reset()

                    sample_obs = timestep.observation["observations"]
                    observation = torch.as_tensor(
                        sample_obs, dtype=torch.float32, device=self.device
                    ).unsqueeze(0)

                    # Binary search to find suitable λ such that F*Z_c = η'
                    for _ in range(50):  # Limit search iterations
                        Z_lambda_c = zs[task] - lagrange_multiplier*c_zs[task]

                        action, _ = agent.act(
                            sample_obs,
                            task=Z_lambda_c,
                            step=None,
                            sample=False,
                        )
                        action = torch.as_tensor(
                            action, dtype=torch.float32, device=self.device
                        ).unsqueeze(0)
                        Z_lambda_c_tensor = torch.as_tensor(
                            Z_lambda_c, dtype=torch.float32, device=self.device
                        ).unsqueeze(0)

                        # Calculate estimated cost value (F*Z_c); evaluation
                        # only, so no autograd graph is needed.
                        with torch.no_grad():
                            F1, F2 = agent.FB.forward_representation(
                                observation=observation, z=Z_lambda_c_tensor, action=action
                            )
                            Q1 = torch.einsum("sd, sd -> s", F1, c_z)
                            Q2 = torch.einsum("sd, sd -> s", F2, c_z)
                            Q = torch.max(Q1, Q2).squeeze(0)

                        # Inner loop: adjust λ based on comparison with budget η'
                        if abs(Q - budget_eta) < 2:
                            # Report the estimate that triggered convergence;
                            # this rollout's cost is not known yet.
                            print(f"Converged! Budget η': {budget_eta}, Estimated cost: {float(Q)}, ETA: {eta}")
                            break
                        elif Q > budget_eta:  # If F*Z_c > η': increase λ
                            lagrange_min = lagrange_multiplier
                            lagrange_multiplier = (lagrange_multiplier + lagrange_max) / 2
                        elif Q < budget_eta:  # If F*Z_c < η': decrease λ
                            lagrange_max = lagrange_multiplier
                            lagrange_multiplier = (lagrange_multiplier + lagrange_min) / 2

                    # Roll out from the same reset state with the found λ.
                    Z_lambda_c = zs[task] - lagrange_multiplier*c_zs[task]
                    while not timestep.last():
                        action, _ = agent.act(
                            timestep.observation["observations"],
                            task=Z_lambda_c,
                            step=None,
                            sample=False,
                        )

                        timestep = self.env.step(action)
                        total_samples += 1
                        task_rewards += self.reward_functions[task](self.env.physics)

                        for dim in specific_dimensions:
                            value = timestep.observation["observations"][dim]
                            if self.domain_name == "point_mass_maze":
                                if abs(value) > sample_range[0]:
                                    task_cost += 100*(abs(value) - sample_range[0])
                            else:
                                if value > sample_range[0]:
                                    task_cost += 10*(value - sample_range[0])

                    eval_rewards.setdefault(task, []).append(task_rewards)
                    eval_cost.setdefault(task, []).append(task_cost)

                    mean_task_cost += task_cost
                    mean_task_rewards += task_rewards

            for task in tasks:
                # Hoeffding-style deviation bound using the largest observed
                # episode cost as the range; the last task's value is the one
                # used below.
                max_cost = max(eval_cost[task])
                budget_tolerance = np.sqrt(
                    max_cost * max_cost * np.log(2/hoeffding_delta) / (2*num_samples)
                )
                print(f"Budget tolerance: {budget_tolerance}")
                print('eval_cost', eval_cost)

            # Calculate average cost and rewards
            mean_task_cost /= len(tasks)
            mean_task_rewards /= len(tasks)
            mean_task_cost /= num_samples
            mean_task_rewards /= num_samples
            print(f"Iter {outer_iter}, Budget η': {budget_eta}, Actual cost: {mean_task_cost}, Reward: {mean_task_rewards}")

            # Outer loop: bisect η' for the first three rounds, then move it
            # by a fixed step.
            if abs(mean_task_cost - eta) < budget_tolerance and mean_task_cost < eta:
                print(f"Converged! Budget η': {budget_eta}, Actual cost: {mean_task_cost}, Total samples: {total_samples}")
                break
            elif mean_task_cost > eta:
                if outer_iter < 3:
                    budget_max = budget_eta
                    budget_eta = (budget_eta + budget_min)/2
                    budget_eta = max(budget_eta, budget_min)
                else:
                    budget_eta -= budget_step

                print(f"Cost too high, decreasing budget to {budget_eta}, Total samples: {total_samples}")
            else:
                if outer_iter < 3:
                    budget_min = budget_eta
                    budget_eta = (budget_eta + budget_max)/2
                    budget_eta = min(budget_eta, budget_max)
                else:
                    budget_eta += budget_step

                print(f"Cost too low, increasing budget to {budget_eta}, Total samples: {total_samples}")

        # Calculate evaluation metrics (keeping original metrics calculation)
        metrics = {}
        mean_task_performance = 0.0
        mean_task_cost_value = 0.0
        for task, rewards in eval_rewards.items():
            reward_iqm = stats.trim_mean(rewards, 0.25)  # IQM
            metrics[f"eval/{task}/episode_reward_iqm"] = reward_iqm
            print(f"eval/{task}/episode_reward_iqm", reward_iqm)
            mean_task_performance += reward_iqm
        for task, costs in eval_cost.items():
            cost_iqm = stats.trim_mean(costs, 0.25)
            metrics[f"eval/{task}/episode_cost_iqm"] = cost_iqm
            print(f"eval/{task}/episode_cost_iqm", cost_iqm)
            mean_task_cost_value += cost_iqm

        metrics["eval/task_reward_iqm"] = mean_task_performance / len(tasks)
        metrics["eval/task_cost_iqm"] = mean_task_cost_value / len(tasks)
        metrics["eval/final_budget_eta"] = budget_eta

        print("Evaluation results:")
        print(f"Evaluation reward: {mean_task_performance / len(tasks)}")
        print(f"Evaluation cost: {mean_task_cost_value / len(tasks)}")
        print(f"Final budget η': {budget_eta}")

        # restore the exploration noise used for training
        agent.std_dev_schedule = self.train_std

        return metrics


class FinetuningWorkspace(OfflineRLWorkspace):
    """
    Finetunes FB or CFB on one task.

    Finetuning runs either offline (gradient steps on batches from a fixed
    replay buffer) or online (interleaving environment interaction with
    updates), selected by the `online` constructor flag.
    """

    def __init__(
        self,
        reward_constructor: RewardFunctionConstructor,
        learning_steps: int,
        model_dir: Path,
        eval_frequency: int,
        eval_rollouts: int,
        wandb_logging: bool,
        online: bool,
        device: torch.device,
        z_inference_steps: Optional[int] = None,  # FB only
        train_std: Optional[float] = None,  # FB only
        eval_std: Optional[float] = None,  # FB only
        critic_tuning: bool = False,  # offline finetuning only
    ):
        """
        Args:
            online: if True `train` finetunes with online interaction,
                otherwise offline from the replay buffer
            critic_tuning: if True offline finetuning also updates the FB
                representations (and their targets); if False only the actor
                is updated
            (remaining arguments are forwarded to OfflineRLWorkspace)
        """
        super().__init__(
            reward_constructor=reward_constructor,
            learning_steps=learning_steps,
            model_dir=model_dir,
            eval_frequency=eval_frequency,
            eval_rollouts=eval_rollouts,
            wandb_logging=wandb_logging,
            device=device,
            z_inference_steps=z_inference_steps,
            train_std=train_std,
            eval_std=eval_std,
        )

        self.online = online
        # Read by tune_offline; must always be set.
        self.critic_tuning = critic_tuning

    def train(
        self,
        agent: Union[FB, CFB],
        tasks: List[str],
        agent_config: Dict,
        replay_buffer: Union[FBReplayBuffer, OnlineFBReplayBuffer],
        seed: int,
        episodes: Optional[int] = None,
    ) -> None:
        """
        Dispatches to online or offline finetuning.
        Args:
            agent: agent to finetune
            tasks: tasks to finetune on (the first entry is the one used)
            agent_config: agent config
            replay_buffer: replay buffer to sample from
            seed: run seed (only used to name the online wandb run)
            episodes: interaction budget for online finetuning; required
                when the workspace is online
        Raises:
            ValueError: if the workspace is online and `episodes` is None
        """

        #assert len(tasks) == 1

        if self.online:
            if episodes is None:
                # Fail before any wandb run or eval rollout is started.
                raise ValueError("episodes must be provided for online finetuning")
            self.tune_online(
                agent=agent,
                task=tasks,
                agent_config=agent_config,
                replay_buffer=replay_buffer,
                episodes=episodes,
                seed=seed,
            )

        else:
            self.tune_offline(
                agent=agent,
                task=tasks,
                agent_config=agent_config,
                replay_buffer=replay_buffer,
            )

    def _make_local_run_dir(self) -> Path:
        """Creates and returns a timestamped local run directory in model_dir."""
        date = datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
        model_path = self.model_dir / f"local-run-{date}"
        makedirs(str(model_path))
        return model_path

    def _prepare_task_inference(self, task: List[str], replay_buffer) -> None:
        """Stores the goal state (point_mass_maze) or (observation, reward) samples used to infer z."""
        if self.domain_name == "point_mass_maze":
            self.goal_states = {}

            goal_state = point_mass_maze_goals[task[0]]
            self.goal_states[task[0]] = torch.tensor(
                goal_state, dtype=torch.float32, device=self.device
            ).unsqueeze(0)
        else:
            (
                self.observations_z,
                self.rewards_z,
            ) = replay_buffer.sample_task_inference_transitions(
                inference_steps=self.z_inference_steps
            )

    def tune_offline(
        self,
        agent: Union[FB, CFB],
        task: List[str],
        agent_config: Dict,
        replay_buffer: FBReplayBuffer,
    ) -> None:
        """
        Finetunes FB or CFB on one task offline, without online interaction.
        Args:
            agent: agent to finetune
            task: task to finetune on (list; the first entry is used)
            agent_config: agent config
            replay_buffer: replay buffer for z sampling
        """

        if self.wandb_logging:
            run = wandb.init(
                config=agent_config,
                tags=[agent.name, "finetuning"],
                reinit=True,
            )

        else:
            self._make_local_run_dir()

        # get observations and rewards for task inference
        self._prepare_task_inference(task, replay_buffer)

        best_mean_task_reward = -np.inf

        # get initial eval metrics
        logger.info("Getting init performance.")
        eval_metrics = self.eval(agent=agent, tasks=task)
        init_performance = eval_metrics["eval/task_reward_iqm"]

        logger.info(f"Finetuning {agent.name} on {self.domain_name}-{task[0]}.")

        for i in tqdm(range(self.learning_steps + 1)):

            batch = replay_buffer.sample(agent.batch_size)

            # infer z for task (re-inferred every step; the representations
            # it depends on may change when critic_tuning is enabled)
            if self.domain_name == "point_mass_maze":
                z = agent.infer_z(self.goal_states[task[0]])
            else:
                z = agent.infer_z(self.observations_z, self.rewards_z[task[0]])

            z_batch = torch.tile(
                torch.as_tensor(z, dtype=torch.float32, device=self.device),
                (agent.batch_size, 1),
            )  # repeat z for batch size

            if self.critic_tuning:
                fb_metrics = agent.update_fb(
                    observations=batch.observations,
                    next_observations=batch.next_observations,
                    actions=batch.actions,
                    discounts=batch.discounts,
                    zs=z_batch,
                    step=i,
                )
                actor_metrics = agent.update_actor(
                    observation=batch.observations, z=z_batch, step=i
                )

                agent.soft_update_params(
                    network=agent.FB.forward_representation,
                    target_network=agent.FB.forward_representation_target,
                    tau=agent._tau,  # pylint: disable=protected-access
                )
                agent.soft_update_params(
                    network=agent.FB.backward_representation,
                    target_network=agent.FB.backward_representation_target,
                    tau=agent._tau,  # pylint: disable=protected-access
                )
                if agent.name in ("VCalFB", "MCalFB"):
                    agent.soft_update_params(
                        network=agent.FB.forward_mu,
                        target_network=agent.FB.forward_mu_target,
                        tau=agent._tau,  # pylint: disable=protected-access
                    )

                train_metrics = {**fb_metrics, **actor_metrics}

            else:
                train_metrics = agent.update_actor(
                    observation=batch.observations, z=z_batch, step=i
                )

            eval_metrics = {}

            if i % self.eval_frequency == 0:
                eval_metrics = self.eval(agent=agent, tasks=task)
                eval_metrics["eval/init_performance"] = init_performance

                if eval_metrics["eval/task_reward_iqm"] > best_mean_task_reward:
                    logger.info(
                        f"Finetuned performance:"
                        f"{eval_metrics['eval/task_reward_iqm']:.1f} |"
                        f" Init performance:"
                        f"{eval_metrics['eval/init_performance']:.1f}"
                    )

                    best_mean_task_reward = eval_metrics["eval/task_reward_iqm"]

                agent.train()

            metrics = {**train_metrics, **eval_metrics}

            if self.wandb_logging:
                run.log(metrics)

        if self.wandb_logging:
            # save to wandb_logging
            run.finish()

    def tune_online(
        self,
        agent: Union[FB, CFB],
        task: List[str],
        agent_config: Dict,
        replay_buffer: OnlineFBReplayBuffer,
        episodes: int,
        seed: int,
    ) -> None:
        """
        Finetunes FB or CFB on one task using online data.
        Args:
            agent: agent to finetune
            task: task to finetune on (list; the first entry is used)
            agent_config: agent config
            replay_buffer: replay buffer for z sampling
            episodes: interaction budget. Despite the name, the counter
                advances once per environment step; a new episode is started
                while the counter is below this value and always runs to
                completion
            seed: run seed, only used in the wandb run name
        """
        crl_flag = 'crl' in agent.name

        if self.wandb_logging:
            run = wandb.init(
                config=agent_config,
                tags=[agent.name, "finetuning"],
                reinit=True,
                name='online+offline_{}_{}_seed_{}_offlineratio{}_trainstd{}_eta1_grad_limi{}_FROMfbcz_check2'.format(self.domain_name, agent.name, seed, replay_buffer.offline_data_ratio, agent.std_dev_schedule, agent.grad_limit)
            )
            model_path = self.model_dir / run.name
            makedirs(str(model_path))

        else:
            model_path = self._make_local_run_dir()

        # get observations and rewards for task inference
        self._prepare_task_inference(task, replay_buffer)

        # get initial eval metrics
        logger.info("Getting init performance.")
        eval_metrics = self.eval(agent=agent, tasks=task)
        init_performance = eval_metrics["eval/task_reward_iqm"]
        best_mean_task_reward = -np.inf

        logger.info(f"Online finetuning {agent.name} on {self.domain_name}-{task[0]}.")
        j = 0  # environment steps taken so far
        while (j < episodes):

            # interact with env; a fresh z (and cost vector / budget for
            # 'crl' agents) is sampled for every episode
            timestep = self.env.reset()
            z = agent.sample_z(size=1)
            if crl_flag:
                c_z = agent.sample_z(size=1).squeeze(0)
                eta = abs(agent.sample_z(size=1)[:,:1]).squeeze(0)
            while not timestep.last():
                observation = timestep.observation["observations"]
                if crl_flag:
                    action, _ = agent.act(
                        observation,
                        task=z.squeeze(),
                        c_z=c_z,
                        eta=eta,
                        step=None,
                        sample=True,
                    )
                else:
                    action, _ = agent.act(
                        observation,
                        task=z.squeeze(),
                        step=None,
                        sample=True,
                    )

                timestep = self.env.step(action)
                reward = self.reward_functions[task[0]](self.env.physics)
                done = timestep.last()
                j += 1

                replay_buffer.add(
                    observation=observation,
                    action=action,
                    reward=reward,
                    next_observation=timestep.observation["observations"],
                    done=done,
                )

                # start learning once batch size is reached
                if j >= agent.batch_size:
                    batch = replay_buffer.sample(agent.batch_size)

                    train_metrics = agent.update(batch=batch, step=j)

                else:
                    train_metrics = {}

                if j % self.eval_frequency == 0:
                    # NOTE: the agent is renamed so the rolling checkpoint is
                    # saved under "last"; the name is not restored afterwards.
                    agent._name = "last"
                    agent.save(model_path)

                    eval_metrics = self.eval(agent=agent, tasks=task)
                    eval_metrics["eval/init_performance"] = init_performance

                    if eval_metrics["eval/task_reward_iqm"] > best_mean_task_reward:
                        logger.info(
                            f"Finetuned performance:"
                            f"{eval_metrics['eval/task_reward_iqm']:.1f} |"
                            f" Init performance:"
                            f"{eval_metrics['eval/init_performance']:.1f}"
                        )

                        best_mean_task_reward = eval_metrics["eval/task_reward_iqm"]
                        safe_best_model_fpath =  agent.save_frequently(model_path, "best_"+str(j)+"_"+time.strftime("%Y-%m-%d-%H-%M-%S"))
                        print("save safe_best_model_fpath", safe_best_model_fpath)
                    agent.save_frequently(model_path, "step_"+str(j)+"_"+time.strftime("%Y-%m-%d-%H-%M-%S"))
                    agent.train()
                else:
                    eval_metrics = {}

                metrics = {**train_metrics, **eval_metrics}

                if self.wandb_logging:
                    run.log(metrics)

        if self.wandb_logging:
            # save to wandb_logging
            run.finish()
    


class D4RLWorkspace:
    """
    Workspace for training and evaluating agents on D4RL tasks.

    Training samples batches from a replay buffer and periodically rolls the
    agent out in ``env``; rollout returns are rescaled against per-domain
    reference scores and summarised as a single ``eval/score`` metric.
    """

    def __init__(
        self,
        env,
        domain_name: str,
        learning_steps: int,
        model_dir: Path,
        eval_frequency: int,
        eval_rollouts: int,
        wandb_logging: bool,
        device: torch.device,
        wandb_project: str,
        wandb_entity: str,
        z_inference_steps: Optional[int] = None,  # FB only
    ):
        """
        Args:
            env: environment used for evaluation rollouts; ``reset()`` must
                return an observation and ``step(action)`` a 4-tuple
                ``(observation, reward, terminated, info)``.
            domain_name: key into the reference-score tables used to
                normalise returns.
            learning_steps: number of update steps (the loop runs
                ``learning_steps + 1`` iterations).
            model_dir: stored on the instance; not used by this class.
            eval_frequency: evaluate every ``eval_frequency`` update steps.
            eval_rollouts: number of rollouts per evaluation.
            wandb_logging: whether to log metrics to Weights & Biases.
            device: stored on the instance; not used by this class.
            wandb_project: W&B project passed to ``wandb.init``.
            wandb_entity: W&B entity passed to ``wandb.init``.
            z_inference_steps: forwarded to the replay buffer when sampling
                task-inference transitions for FB/SF agents.
        """
        self.env = env
        self.domain_name = domain_name
        self.learning_steps = learning_steps
        self.model_dir = model_dir
        self.eval_frequency = eval_frequency
        self.eval_rollouts = eval_rollouts
        self.wandb_logging = wandb_logging
        self.device = device
        self.wandb_project = wandb_project
        self.wandb_entity = wandb_entity
        self.z_inference_steps = z_inference_steps

        # Task-inference transitions for FB/SF agents; populated by train().
        # Initialised here so eval() can detect that they are missing instead
        # of failing with an AttributeError.
        self.goals_z = None
        self.rewards_z = None

        # Per-domain returns that map to 100 and 0 after normalisation.
        # NOTE(review): presumably the D4RL expert/random reference returns --
        # confirm the source of these constants.
        self.ref_max_score = {
            "walker": 4592.3,
            "cheetah": 12135.0,
        }
        self.ref_min_score = {
            "cheetah": -280.178953,
            "walker": 1.629008,
        }

    def train(
        self,
        agent: "Union[FB, CFB, SF]",
        agent_config: Dict,
        replay_buffer: "D4RLReplayBuffer",
    ) -> None:
        """
        Trains ``agent`` from ``replay_buffer``, evaluating it periodically.

        Args:
            agent: agent to train; updated in place via ``agent.update``.
            agent_config: configuration logged to W&B as the run config.
            replay_buffer: source of training batches and, for FB/SF agents,
                of the transitions used for task inference.
        """

        run = None
        if self.wandb_logging:
            run = wandb.init(
                entity=self.wandb_entity,
                project=self.wandb_project,
                config=agent_config,
                tags=[agent.name, "D4RL"],
                reinit=True,
            )

        succeeded = False
        try:
            logger.info(f"Training {agent.name}.")
            best_mean_task_reward = -np.inf

            # sample set transitions for z inference
            if isinstance(agent, (FB, SF)):
                (
                    self.goals_z,
                    self.rewards_z,
                ) = replay_buffer.sample_task_inference_transitions(
                    inference_steps=self.z_inference_steps,
                )

            for i in tqdm(range(self.learning_steps + 1)):

                batch = replay_buffer.sample(agent.batch_size)
                train_metrics = agent.update(batch=batch, step=i)

                eval_metrics = {}

                if i % self.eval_frequency == 0:
                    eval_metrics = self.eval(agent=agent)

                    if eval_metrics["eval/score"] > best_mean_task_reward:
                        new_best_mean_task_reward = eval_metrics["eval/score"]
                        logger.info(
                            f"New max IQM task reward: "
                            f"{best_mean_task_reward:.3f} -> "
                            f"{new_best_mean_task_reward:.3f}."
                        )

                        best_mean_task_reward = new_best_mean_task_reward

                    # switch back to training mode after evaluation
                    agent.train()

                metrics = {**train_metrics, **eval_metrics}

                if run is not None:
                    run.log(metrics)

            succeeded = True
        finally:
            # Always close the W&B run so a crash does not leave it active;
            # a non-zero exit code marks the run as failed.
            if run is not None:
                if succeeded:
                    run.finish()
                else:
                    run.finish(exit_code=1)

    def eval(self, agent: "Union[FB, CFB, SF]"):
        """
        Evals agent.

        Runs ``self.eval_rollouts`` rollouts in ``self.env`` with
        ``sample=False`` and normalises the summed rewards against the
        reference scores of ``self.domain_name``.

        Args:
            agent: agent to evaluate.
        Returns:
            Dict with a single key ``"eval/score"``: the mean of the
            normalised returns after trimming 25% from each tail.
        Raises:
            KeyError: if ``self.domain_name`` has no reference scores.
            RuntimeError: if an FB/SF agent is evaluated before the
                task-inference transitions have been sampled.
        """

        logger.info(f"Evaluating {agent.name}.")

        # Fail before running any rollouts if the scores cannot be normalised.
        self._reference_scores()

        # NOTE(review): train() calls agent.train() after each eval, but the
        # agent is never switched to eval mode here -- confirm that is intended.
        infers_task = isinstance(agent, (FB, SF))
        z = None
        if infers_task:
            goals_z = getattr(self, "goals_z", None)
            rewards_z = getattr(self, "rewards_z", None)
            if goals_z is None or rewards_z is None:
                raise RuntimeError(
                    "Task-inference transitions are not set; call train() "
                    "or set goals_z/rewards_z before evaluating an FB/SF "
                    "agent."
                )
            z = agent.infer_z(goals_z, rewards_z)

        eval_rewards = np.zeros(self.eval_rollouts)

        for i in tqdm(range(self.eval_rollouts), desc="eval rollouts"):

            observation = self.env.reset()
            terminated = False
            rollout_reward = 0.0

            while not terminated:
                if infers_task:
                    action, _ = agent.act(
                        observation=observation, task=z, sample=False, step=None
                    )
                else:
                    action = agent.act(observation=observation, sample=False, step=None)
                observation, reward, terminated, _ = self.env.step(action)
                rollout_reward += reward

            eval_rewards[i] = rollout_reward

        eval_rewards = self._get_normalised_score(eval_rewards)
        metrics = {"eval/score": float(stats.trim_mean(eval_rewards, 0.25))}

        return metrics

    def _reference_scores(self):
        """
        Returns ``(min_score, max_score)`` for ``self.domain_name``.

        Raises:
            KeyError: if the domain is missing from either reference table.
        """
        try:
            return (
                self.ref_min_score[self.domain_name],
                self.ref_max_score[self.domain_name],
            )
        except KeyError:
            known = sorted(set(self.ref_min_score) & set(self.ref_max_score))
            raise KeyError(
                f"No reference scores for domain '{self.domain_name}'; "
                f"known domains: {known}."
            ) from None

    def _get_normalised_score(self, score: np.ndarray):
        """
        Rescales raw returns so the reference minimum maps to 0 and the
        reference maximum maps to 100.

        Args:
            score: raw episode return(s).
        Returns:
            ``(score - min) / (max - min) * 100`` for the current domain.
        Raises:
            KeyError: if ``self.domain_name`` has no reference scores.
        """
        min_score, max_score = self._reference_scores()
        return (score - min_score) / (max_score - min_score) * 100
